/********************************************************************************
 * Copyright 2017 The Robotics Group, The Maersk Mc-Kinney Moller Institute,
 * Faculty of Engineering, University of Southern Denmark
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ********************************************************************************/

#include "PolynomialSolver.hpp"

#include "Polynomial.hpp"

#include <rw/core/macros.hpp>
#include <rw/math/Constants.hpp>

#include <complex>
#include <limits>
#include <stddef.h>

using namespace rw::math;

PolynomialSolver::PolynomialSolver (const Polynomial<>& polynomial) :
    _polynomial (polynomial.order ()), _derivative (NULL), _guess (0), _iterations (10)
{
    for (std::size_t i = 0; i <= polynomial.order (); i++)
        _polynomial[i] = polynomial[i];
}

PolynomialSolver::PolynomialSolver (const Polynomial< std::complex< double > >& polynomial) :
    _polynomial (polynomial), _derivative (NULL), _guess (0), _iterations (10)
{}

PolynomialSolver::~PolynomialSolver ()
{
    if (_derivative != NULL)
        delete _derivative;
}

void PolynomialSolver::setInitialGuess (std::complex< double > guess)
{
    _guess = guess;
}

std::vector< double > PolynomialSolver::getRealSolutions (const double EPS)
{
    std::vector< std::complex< double > > sol = getSolutions (EPS);
    std::vector< double > res;

    for (std::size_t i = 0; i < sol.size (); i++) {
        std::complex< double > x = sol[i];
        if (std::fabs (x.imag ()) <= 2. * EPS * std::fabs (x.real ())) {
            res.push_back (x.real ());
        }
    }

    return res;
}

std::vector< std::complex< double > > PolynomialSolver::getSolutions (const double EPS)
{
    std::vector< std::complex< double > > res;

    // Remove highest order coefficients if close to zero
    PolynomialSolver* polZ = this;
    PolynomialSolver* pol  = this;
    if (pol->_polynomial.order () <= 0)
        return res;
    if (std::abs (pol->_polynomial[pol->_polynomial.order ()]) <= EPS) {
        std::size_t zeroIdx = 0;
        for (std::size_t i = 1; i < pol->_polynomial.order (); i++) {
            if (std::abs (pol->_polynomial[pol->_polynomial.order () - i]) <= EPS) {
                zeroIdx = i;
            }
            else {
                break;
            }
        }
        Polynomial< std::complex< double > > newPol (pol->_polynomial.order () - 1 - zeroIdx);
        for (std::size_t i = 0; i < pol->_polynomial.order () - zeroIdx; i++) {
            newPol[i] = pol->_polynomial[i];
        }
        if (newPol.order () <= 0)
            return res;
        pol  = new PolynomialSolver (newPol);
        polZ = new PolynomialSolver (newPol);
    }

    // Reduce order if lowest order coefficients is close to zero
    if (std::abs (pol->_polynomial[0]) <= EPS) {
        std::size_t zeroIdx = 0;
        for (std::size_t i = 1; i <= pol->_polynomial.order (); i++) {
            if (std::abs (pol->_polynomial[i]) <= EPS) {
                zeroIdx = i;
            }
            else {
                break;
            }
        }
        for (std::size_t i = 0; i <= zeroIdx; i++)
            res.push_back (0.);
        Polynomial< std::complex< double > > newPol (pol->_polynomial.order () - 1 - zeroIdx);
        for (std::size_t i = 0; i < pol->_polynomial.order () - zeroIdx; i++) {
            newPol[i] = pol->_polynomial[i + zeroIdx + 1];
        }
        if (pol != this)
            delete pol;
        if (polZ != this)
            delete polZ;
        if (newPol.order () <= 0)
            return res;
        pol  = new PolynomialSolver (newPol);
        polZ = new PolynomialSolver (newPol);
    }

    //	// Normalize polynomial
    //	{
    //	    const Polynomial<std::complex<double> >& p = pol->_polynomial;
    //	    const double scale = std::abs(p[p.order()]);
    //        Polynomial<std::complex<double> > newPol(p.order());
    //        for (std::size_t i = 0; i <= p.order(); i++) {
    //            newPol[i] = p[i]/scale;
    //        }
    //        if (pol != this)
    //            delete pol;
    //        if (polZ != this)
    //            delete polZ;
    //        pol = new PolynomialSolver(newPol);
    //        polZ = new PolynomialSolver(newPol);
    //	}

    // Try to solve analytically if the polynomial is simple enough
    bool analyticSolutions = pol->analyticalSolutions (res, EPS);

    if (!analyticSolutions) {
        // Find tentative roots
        for (std::size_t i = 0; i < polZ->_polynomial.order () && !analyticSolutions; i++) {
            // No analytical solution can be found yet, so use Laguerre to find a root
            std::complex< double > x = _guess;
            x                        = pol->laguerre (x);
            res.push_back (x);
            PolynomialSolver* old = pol;
            pol                   = new PolynomialSolver (pol->_polynomial.deflate (x));
            if (old != this)
                delete old;
            analyticSolutions = pol->analyticalSolutions (res, EPS);
        }

        // Polish
        for (std::size_t i = 0; i < polZ->_polynomial.order (); i++) {
            res[i] = polZ->laguerre (res[i]);
        }
    }

    if (pol != this)
        delete pol;
    if (polZ != this)
        delete polZ;

    return res;
}

void PolynomialSolver::setLaguerreIterations (unsigned int iterations)
{
    _iterations = iterations;
}

bool PolynomialSolver::analyticalSolutions (std::vector< std::complex< double > >& res,
                                            const double EPS) const
{
    // Check if polynomial is of the trivial format a x^n + b = 0
    bool trivial = true;
    for (std::size_t i = 1; i < _polynomial.order () && trivial; i++) {
        if (std::abs (_polynomial[i]) > EPS)
            trivial = false;
    }
    if (trivial) {
        const Polynomial< std::complex< double > >& p = _polynomial;
        const double order                            = static_cast< double > (p.order ());
        const std::complex< double > val              = -p[0] / p[p.order ()];
        const double z                                = std::pow (std::abs (val), 1. / order);
        const double arg                              = std::arg (val);
        if (p.order () == 1) {
            res.push_back (-p[0] / p[1]);
        }
        else if (p.order () % 2 == 0) {
            for (std::size_t i = 0; i < p.order () / 2; i++) {
                const double real = z * std::cos ((arg + Pi * 2 * i) / order);
                const double imag = z * std::sin ((arg + Pi * 2 * i) / order);
                res.push_back (std::complex< double > (real, imag));
                res.push_back (std::complex< double > (-real, -imag));
            }
        }
        else {
            for (std::size_t i = 0; i < p.order (); i++) {
                const double real = z * std::cos ((arg + Pi * 2 * i) / order);
                const double imag = z * std::sin ((arg + Pi * 2 * i) / order);
                res.push_back (std::complex< double > (real, imag));
            }
        }
        return true;
    }
    else if (_polynomial.order () <=
             3) {    // Check if polynomial is degree 3 or less and solve analytically
        const Polynomial< std::complex< double > >& p = _polynomial;
        // For order 0 there will be no solution
        // Order 1 is a trivial case
        if (_polynomial.order () == 2) {
            const std::complex< double > K = std::sqrt (p[1] * p[1] - p[2] * p[0] * 4.);
            res.push_back ((-p[1] + K) / (p[2] * 2.));
            res.push_back ((-p[1] - K) / (p[2] * 2.));
        }
        else if (_polynomial.order () == 3) {
            const double scale                 = std::abs (p[3]);
            const std::complex< double > a     = p[3] / scale;
            const std::complex< double > b     = p[2] / scale;
            const std::complex< double > c     = p[1] / scale;
            const std::complex< double > d     = p[0] / scale;
            const std::complex< double > a3    = a * 3.;
            const std::complex< double > a3a3  = a3 * a3;
            const std::complex< double > a4    = a * 4.;
            const std::complex< double > bc    = b * c;
            const std::complex< double > bPow3 = b * b * b;
            const std::complex< double > cPow3 = c * c * c;
            const std::complex< double > D     = a * bc * d * 18. - bPow3 * d * 4. + bc * bc -
                                             a4 * cPow3 - a3a3 * d * d * 3.;    // discriminant
            const std::complex< double > D0 = b * b - c * a3;
            if (std::abs (D) < EPS) {
                if (std::abs (D0) < EPS) {
                    const std::complex< double > root = -b / a3;
                    res.push_back (root);
                    res.push_back (root);
                    res.push_back (root);
                }
                else {
                    const std::complex< double > root = (a3 * d * 3. - bc) / (D0 * 2.);
                    res.push_back (root);
                    res.push_back (root);
                    res.push_back ((a4 * bc - a3a3 * d - bPow3) / (D0 * a));
                }
            }
            else {
                const std::complex< double > D1   = bPow3 * 2. - a3 * bc * 3. + a3a3 * d * 3.;
                const std::complex< double > K    = D1 * D1 - D0 * D0 * D0 * 4.;
                const std::complex< double > CinM = (D1 + std::sqrt (K)) / 2.;
                const std::complex< double > CinP = (D1 - std::sqrt (K)) / 2.;
                const std::complex< double > Cin =
                    (std::abs (CinM) > std::abs (CinP)) ? CinM : CinP;
                const double z   = std::pow (std::abs (Cin), 1. / 3.);
                const double arg = std::arg (Cin);
                const std::complex< double > C0 (z * std::cos (arg / 3), z * std::sin (arg / 3));
                const std::complex< double > C1 =
                    C0 * std::complex< double > (-1. / 2., std::sqrt (3.) / 2.);
                const std::complex< double > C2 =
                    C0 * std::complex< double > (-1. / 2., -std::sqrt (3.) / 2.);
                res.push_back (-b / a3 - (C0 + D0 / C0) / a3);
                res.push_back (-b / a3 - (C1 + D0 / C1) / a3);
                res.push_back (-b / a3 - (C2 + D0 / C2) / a3);
            }
        }
        return true;
    }
    return false;
}

std::complex< double > PolynomialSolver::laguerre (std::complex< double > x) const
{
    const double n = 1.0 + _polynomial.order ();
    for (std::size_t i = 1; i <= _iterations; i++) {
        double err;
        std::complex< double > val = _polynomial.evaluate (x, err);
        if (std::abs (val) <= err)
            return x;
        const std::vector< std::complex< double > > der = _polynomial.evaluateDerivatives (x, 2);
        const std::complex< double > G                  = der[1] / val;
        const std::complex< double > Gsq                = G * G;
        const std::complex< double > H                  = Gsq - der[2] / val;
        const std::complex< double > sq                 = std::sqrt ((n - 2) * ((n - 1) * H - Gsq));
        const std::complex< double > denPlus            = G + sq;
        const std::complex< double > denMinus           = G - sq;
        const std::complex< double > den =
            (std::abs (denPlus) > std::abs (denMinus)) ? denPlus : denMinus;
        if (std::abs (den) == 0)
            RW_THROW ("PolynomialSolver::laguerre failed du to zero denominator!");
        const std::complex< double > a = (n - 1) / den;
        std::complex< double > xNew    = x - a;
        if (xNew == x)
            return xNew;
        x = xNew;
    }

    // RW_THROW("PolynomialSolver::laguerre failed to find solution - too many iterations.");
    RW_WARN ("PolynomialSolver::laguerre failed to find solution - too many iterations.");
    return 0;
}
